{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to integrate Spring AI LLM streaming in LangGraph4j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Locations and versions consumed by the dependency-resolution cell further down.\n",
    "String userHomeDir = System.getProperty(\"user.home\");\n",
    "// Local Maven repository under the user's home directory, expressed as a file:// URL.\n",
    "String localRespoUrl = \"file://\" + userHomeDir + \"/.m2/repository/\";\n",
    "// Versions of the Spring AI and LangGraph4j artifacts to resolve.\n",
    "String springaiVersion = \"1.0.0\";\n",
    "String langgraph4jVersion = \"1.6-SNAPSHOT\";"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remove the installed packages from the Jupyter kernel cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash \n",
    "rm -rf \\{userHomeDir}/Library/Jupyter/kernels/rapaio-jupyter-kernel/mima_cache/org/bsc/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%dependency /add-repo local \\{localRespoUrl} release|never snapshot|always\n",
    "// %dependency /list-repos\n",
    "%dependency /add org.slf4j:slf4j-jdk14:2.0.9\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-spring-ai:\\{langgraph4jVersion}\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-springai-agentexecutor:\\{langgraph4jVersion}\n",
    "%dependency /add org.springframework.ai:spring-ai-commons:\\{springaiVersion}\n",
    "%dependency /add org.springframework.ai:spring-ai-model:\\{springaiVersion}\n",
    "%dependency /add org.springframework.ai:spring-ai-client-chat:\\{springaiVersion}\n",
    "%dependency /add org.springframework.ai:spring-ai-ollama:\\{springaiVersion}\n",
    "%dependency /add org.springframework.ai:spring-ai-openai:\\{springaiVersion}\n",
    "%dependency /resolve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize Logger**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Configure java.util.logging from the logging.properties file in the working directory.\n",
    "try (var configStream = new java.io.FileInputStream(\"./logging.properties\")) {\n",
    "    var logManager = java.util.logging.LogManager.getLogManager();\n",
    "    logManager.readConfiguration(configStream);\n",
    "}\n",
    "\n",
    "// SLF4J logger shared by the cells that follow.\n",
    "var log = org.slf4j.LoggerFactory.getLogger(\"llm-streaming\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to use StreamingChatGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Received: StreamingOutput{node=agent, state=null, chunk=Sure}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk=,}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= here}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk='s}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= a}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= light}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= joke}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= for}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= you}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk=:\n",
      "\n",
      "}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk=Why}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= don}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk='t}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= scientists}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= trust}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= atoms}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk=?\n",
      "\n",
      "}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk=Because}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= they}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= make}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= up}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk= everything}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk=!}\n",
      "Received: StreamingOutput{node=agent, state=null, chunk=}\n"
     ]
    }
   ],
   "source": [
    "import org.bsc.async.AsyncGenerator;\n",
    "import org.bsc.async.FlowGenerator;\n",
    "import org.bsc.langgraph4j.NodeOutput;\n",
    "import org.bsc.langgraph4j.StateGraph;\n",
    "import org.bsc.langgraph4j.action.AsyncNodeAction;\n",
    "import org.bsc.langgraph4j.action.EdgeAction;\n",
    "import org.bsc.langgraph4j.action.NodeAction;\n",
    "import org.bsc.langgraph4j.prebuilt.MessagesState;\n",
    "import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;\n",
    "import org.bsc.langgraph4j.spring.ai.generators.StreamingChatGenerator;\n",
    "import org.bsc.langgraph4j.spring.ai.serializer.std.SpringAIStateSerializer;\n",
    "import org.bsc.langgraph4j.spring.ai.tool.SpringAIToolService;\n",
    "import org.bsc.langgraph4j.streaming.StreamingOutput;\n",
    "import org.bsc.langgraph4j.utils.EdgeMappings;\n",
    "import org.reactivestreams.FlowAdapters;\n",
    "import org.reactivestreams.Publisher;\n",
    "import org.springframework.ai.chat.client.ChatClient;\n",
    "import org.springframework.ai.chat.messages.AssistantMessage;\n",
    "import org.springframework.ai.chat.messages.Message;\n",
    "import org.springframework.ai.chat.messages.MessageType;\n",
    "import org.springframework.ai.chat.messages.UserMessage;\n",
    "import org.springframework.ai.chat.model.ChatModel;\n",
    "import org.springframework.ai.chat.model.ChatResponse;\n",
    "import org.springframework.ai.model.tool.ToolCallingChatOptions;\n",
    "import org.springframework.ai.ollama.OllamaChatModel;\n",
    "import org.springframework.ai.ollama.api.OllamaApi;\n",
    "import org.springframework.ai.ollama.api.OllamaOptions;\n",
    "import org.springframework.ai.openai.OpenAiChatModel;\n",
    "import org.springframework.ai.openai.OpenAiChatOptions;\n",
    "import org.springframework.ai.openai.api.OpenAiApi;\n",
    "import org.springframework.ai.tool.annotation.Tool;\n",
    "import org.springframework.ai.tool.annotation.ToolParam;\n",
    "import org.springframework.ai.tool.function.FunctionToolCallback;\n",
    "import reactor.core.publisher.Flux;\n",
    "\n",
    "import java.util.List;\n",
    "import java.util.Map;\n",
    "import java.util.concurrent.CompletableFuture;\n",
    "import java.util.concurrent.Flow;\n",
    "import java.util.concurrent.atomic.AtomicReference;\n",
    "import java.util.function.Consumer;\n",
    "import java.util.function.Function;\n",
    "import java.util.function.Supplier;\n",
    "import java.util.function.UnaryOperator;\n",
    "\n",
    "import static java.util.Optional.ofNullable;\n",
    "import static org.bsc.langgraph4j.StateGraph.END;\n",
    "import static org.bsc.langgraph4j.StateGraph.START;\n",
    "import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;\n",
    "import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;\n",
    "\n",
    "/**\n",
    " * Chat models used in this notebook; each constant exposes its model through {@link #model}.\n",
    " *\n",
    " * NOTE(review): both constants are constructed when the enum is initialised, so\n",
    " * OPENAI_API_KEY is read from the environment even when only the Ollama model is used.\n",
    " */\n",
    "enum AiModel {\n",
    "\n",
    "    OPENAI_GPT_4O_MINI(openAiGpt4oMini()),\n",
    "    OLLAMA_QWEN2_5_7B(ollamaQwen25());\n",
    "\n",
    "    /** The chat model backing this constant. */\n",
    "    public final ChatModel model;\n",
    "\n",
    "    AiModel(ChatModel chatModel) {\n",
    "        this.model = chatModel;\n",
    "    }\n",
    "\n",
    "    /** Builds the OpenAI gpt-4o-mini model, taking the API key from OPENAI_API_KEY. */\n",
    "    private static ChatModel openAiGpt4oMini() {\n",
    "        var api = OpenAiApi.builder()\n",
    "                .baseUrl(\"https://api.openai.com\")\n",
    "                .apiKey(System.getenv(\"OPENAI_API_KEY\"))\n",
    "                .build();\n",
    "        var options = OpenAiChatOptions.builder()\n",
    "                .model(\"gpt-4o-mini\")\n",
    "                .logprobs(false)\n",
    "                .temperature(0.1)\n",
    "                .build();\n",
    "        return OpenAiChatModel.builder()\n",
    "                .openAiApi(api)\n",
    "                .defaultOptions(options)\n",
    "                .build();\n",
    "    }\n",
    "\n",
    "    /** Builds the qwen2.5:7b model served by an Ollama instance at localhost:11434. */\n",
    "    private static ChatModel ollamaQwen25() {\n",
    "        var api = OllamaApi.builder()\n",
    "                .baseUrl(\"http://localhost:11434\")\n",
    "                .build();\n",
    "        var options = OllamaOptions.builder()\n",
    "                .model(\"qwen2.5:7b\")\n",
    "                .temperature(0.1)\n",
    "                .build();\n",
    "        return OllamaChatModel.builder()\n",
    "                .ollamaApi(api)\n",
    "                .defaultOptions(options)\n",
    "                .build();\n",
    "    }\n",
    "}\n",
    "\n",
    "// Chat client backed by the local Ollama model; automatic tool execution is disabled.\n",
    "var chatClient = ChatClient\n",
    "        .builder(AiModel.OLLAMA_QWEN2_5_7B.model)\n",
    "        .defaultOptions(\n",
    "                ToolCallingChatOptions.builder()\n",
    "                        .internalToolExecutionEnabled(false)\n",
    "                        .build())\n",
    "        .defaultSystem(\"You are a helpful AI Assistant answering questions.\")\n",
    "        .build();\n",
    "\n",
    "// Streamed chat responses for a single user prompt.\n",
    "var flux = chatClient.prompt()\n",
    "        .messages(new UserMessage(\"tell me a joke\"))\n",
    "        .stream()\n",
    "        .chatResponse();\n",
    "\n",
    "// Wrap the stream in a generator whose outputs are attributed to the \"agent\" node;\n",
    "// mapResult stores the response's output message under the \"messages\" key.\n",
    "var generator = StreamingChatGenerator.builder()\n",
    "        .startingNode(\"agent\")\n",
    "        .mapResult(chatResponse -> Map.of(\"messages\", chatResponse.getResult().getOutput()))\n",
    "        .build(flux);\n",
    "\n",
    "// Print each item as it is produced.\n",
    "for (var streamed : generator) {\n",
    "    System.out.println(\"Received: \" + streamed);\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use StreamingChatGenerator in Agent Executor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the agent's tools\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "/**\n",
    " * Demo tool made available to the agent. It returns a canned weather report\n",
    " * instead of calling a real weather service.\n",
    " */\n",
    "public class WeatherTool {\n",
    "\n",
    "    /**\n",
    "     * Returns a fixed weather description; this placeholder ignores its argument.\n",
    "     *\n",
    "     * @param query the location the weather is requested for\n",
    "     * @return a hard-coded weather report\n",
    "     */\n",
    "    @Tool( description = \"Get the weather in location\")\n",
    "    public String execQuery(@ToolParam( description = \"The location (e.g. a city name) to get the weather for.\") String query) {\n",
    "        // Placeholder: a real implementation would look up the weather for the given location.\n",
    "        return \"Cold, with a low of 13 degrees\";\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the Agent Executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "START \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__START__\n",
      "agent: ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "executeTools \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agent\n",
      "action\n",
      "agent: ()\n",
      "agent: (The)\n",
      "agent: ( weather)\n",
      "agent: ( in)\n",
      "agent: ( Napoli)\n",
      "agent: ( is)\n",
      "agent: ( currently)\n",
      "agent: ( cold)\n",
      "agent: (,)\n",
      "agent: ( with)\n",
      "agent: ( a)\n",
      "agent: ( low)\n",
      "agent: ( of)\n",
      "agent: ( )\n",
      "agent: (13)\n",
      "agent: ( degrees)\n",
      "agent: ( Celsius)\n",
      "agent: (.)\n",
      "agent: (null)\n",
      "agent\n",
      "__END__\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "result: The weather in Napoli is currently cold, with a low of 13 degrees Celsius. \n"
     ]
    }
   ],
   "source": [
    "import org.bsc.langgraph4j.spring.ai.agentexecutor.AgentExecutor;\n",
    "import org.bsc.langgraph4j.NodeOutput;\n",
    "import org.bsc.langgraph4j.streaming.StreamingOutput;\n",
    "import org.springframework.ai.chat.messages.AssistantMessage;\n",
    "import org.springframework.ai.chat.messages.UserMessage;\n",
    "import org.springframework.ai.chat.model.ChatModel;\n",
    "import org.springframework.ai.tool.ToolCallback;\n",
    "\n",
    "\n",
    "// Compile an agent graph that streams from the OpenAI model and exposes WeatherTool.\n",
    "var agent = AgentExecutor.builder()\n",
    "        .streamingChatModel(AiModel.OPENAI_GPT_4O_MINI.model)\n",
    "        .toolsFromObject(new WeatherTool())\n",
    "        .build()\n",
    "        .compile();\n",
    "\n",
    "var result = agent.stream(Map.of(\"messages\", new UserMessage(\"Weather in Napoli ?\")));\n",
    "\n",
    "// Print every output as it arrives and keep the state carried by the last one.\n",
    "var state = result.stream()\n",
    "        .peek(output -> {\n",
    "            if (!(output instanceof StreamingOutput<?> streamed)) {\n",
    "                // Regular node output: print only the node name.\n",
    "                System.out.println(output.node());\n",
    "                return;\n",
    "            }\n",
    "            // Streaming output: print the node name together with the chunk.\n",
    "            System.out.printf(\"%s: (%s)\\n\", streamed.node(), streamed.chunk());\n",
    "        })\n",
    "        .reduce((previous, latest) -> latest)\n",
    "        .map(NodeOutput::state)\n",
    "        .orElseThrow();\n",
    "\n",
    "// Log the text of the last message, which is cast to AssistantMessage.\n",
    "log.info(\"result: {}\",\n",
    "        state.lastMessage()\n",
    "                .map(AssistantMessage.class::cast)\n",
    "                .map(AssistantMessage::getText)\n",
    "                .orElseThrow());"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java (rjk 2.2.0)",
   "language": "java",
   "name": "rapaio-jupyter-kernel"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "java",
   "nbconvert_exporter": "script",
   "pygments_lexer": "java",
   "version": "22.0.2+9-70"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
